K-NN (k-nearest neighbors) in Python
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
def k_nearest_neighbor(the_k, dataset, predictd_date):
    """Classify a point by majority vote among its nearest labelled samples.

    Distances are computed with ``np.linalg.norm`` on the difference between
    the query point and every sample in ``dataset``.

    Args:
        the_k: Number of nearest samples that take part in the vote. Must be
            at least 1. If it exceeds the number of samples, every sample
            votes.
        dataset: Mapping from group label to a list of sample points, e.g.
            ``{'black': [[1, 2], [2, 3]], 'red': [[6, 5]]}``.
        predictd_date: The point to classify (same dimensionality as the
            samples).

    Returns:
        A ``(label, confidence)`` tuple: ``label`` is the most common group
        among the voters and ``confidence`` is the fraction of voters that
        belong to that group.

    Raises:
        ValueError: If ``the_k`` is smaller than 1 or ``dataset`` holds no
            samples.
    """
    if the_k < 1:
        # A zero/negative k would otherwise slice the neighbour list in a
        # meaningless way (e.g. [: -1] drops only the farthest sample).
        raise ValueError("the_k must be at least 1, got %r" % (the_k,))

    # Convert the query once instead of once per sample.
    query = np.array(predictd_date)

    distances = []
    for group in dataset:
        for sample in dataset[group]:
            distance = np.linalg.norm(query - np.array(sample))
            distances.append((distance, group))

    if not distances:
        raise ValueError("dataset contains no samples")

    # Sorting the (distance, group) pairs orders by distance first and breaks
    # exact ties by group label.
    nearest_groups = [group for _, group in sorted(distances)[:the_k]]

    result, votes = Counter(nearest_groups).most_common(1)[0]
    # Divide by the number of samples that actually voted, which is smaller
    # than the_k when the dataset has fewer than the_k samples.
    confidence = float(votes) / len(nearest_groups)
    return result, confidence
if __name__ == '__main__':
    # Demo: plot two labelled clusters, then classify and plot new points.
    # Group labels double as matplotlib colour names.
    dataset = {
        'black': [[1, 2], [2, 3], [3, 1]],
        'red': [[6, 5], [7, 7], [8, 6]],
    }
    new_features = [[6.5, 7.2], [3.1, 3.2]]

    # Plot all labelled samples in the colour of their group.
    figure = plt.figure()
    axes = figure.add_subplot(111)
    for group in dataset:
        for data in dataset[group]:
            axes.scatter(data[0], data[1], s=50, color=group)

    # Classify each new point on its own. Passing the whole list at once
    # would make NumPy broadcast both points into a single distance, and
    # scatter(new_features[0], new_features[1]) would treat the first point
    # as the x-values and the second as the y-values.
    for feature in new_features:
        result, _confidence = k_nearest_neighbor(3, dataset, feature)
        # Slightly larger marker, drawn in the colour of the predicted group.
        axes.scatter(feature[0], feature[1], s=60, color=result)

    plt.show()
Author: dctrl
Created: 2018-11-08 Thu 19:26
Validate